refactor: replace SQLModel with QueryModel in various models and update query methods
This commit is contained in:
@@ -10,7 +10,6 @@ from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api.deps import require_org_admin, require_org_member
|
||||
from app.api.queryset import api_qs
|
||||
from app.core.auth import AuthContext, get_auth_context
|
||||
from app.core.time import utcnow
|
||||
from app.db import crud
|
||||
@@ -81,14 +80,10 @@ async def _require_org_member(
|
||||
organization_id: UUID,
|
||||
member_id: UUID,
|
||||
) -> OrganizationMember:
|
||||
return await (
|
||||
api_qs(OrganizationMember)
|
||||
.filter(
|
||||
col(OrganizationMember.id) == member_id,
|
||||
col(OrganizationMember.organization_id) == organization_id,
|
||||
)
|
||||
.first_or_404(session)
|
||||
)
|
||||
member = await OrganizationMember.objects.by_id(member_id).first(session)
|
||||
if member is None or member.organization_id != organization_id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
return member
|
||||
|
||||
|
||||
async def _require_org_invite(
|
||||
@@ -97,14 +92,10 @@ async def _require_org_invite(
|
||||
organization_id: UUID,
|
||||
invite_id: UUID,
|
||||
) -> OrganizationInvite:
|
||||
return await (
|
||||
api_qs(OrganizationInvite)
|
||||
.filter(
|
||||
col(OrganizationInvite.id) == invite_id,
|
||||
col(OrganizationInvite.organization_id) == organization_id,
|
||||
)
|
||||
.first_or_404(session)
|
||||
)
|
||||
invite = await OrganizationInvite.objects.by_id(invite_id).first(session)
|
||||
if invite is None or invite.organization_id != organization_id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
return invite
|
||||
|
||||
|
||||
@router.post("", response_model=OrganizationRead)
|
||||
@@ -157,7 +148,7 @@ async def list_my_organizations(
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
await get_active_membership(session, auth.user)
|
||||
db_user = await session.get(User, auth.user.id)
|
||||
db_user = await User.objects.by_id(auth.user.id).first(session)
|
||||
active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id
|
||||
|
||||
statement = (
|
||||
@@ -189,7 +180,7 @@ async def set_active_org(
|
||||
member = await set_active_organization(
|
||||
session, user=auth.user, organization_id=payload.organization_id
|
||||
)
|
||||
organization = await session.get(Organization, member.organization_id)
|
||||
organization = await Organization.objects.by_id(member.organization_id).first(session)
|
||||
if organization is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
return OrganizationRead.model_validate(organization, from_attributes=True)
|
||||
@@ -293,14 +284,10 @@ async def get_my_membership(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_member),
|
||||
) -> OrganizationMemberRead:
|
||||
user = await session.get(User, ctx.member.user_id)
|
||||
access_rows = list(
|
||||
await session.exec(
|
||||
select(OrganizationBoardAccess).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == ctx.member.id
|
||||
)
|
||||
)
|
||||
)
|
||||
user = await User.objects.by_id(ctx.member.user_id).first(session)
|
||||
access_rows = await OrganizationBoardAccess.objects.filter_by(
|
||||
organization_member_id=ctx.member.id
|
||||
).all(session)
|
||||
model = _member_to_read(ctx.member, user)
|
||||
model.board_access = [
|
||||
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
|
||||
@@ -342,14 +329,10 @@ async def get_org_member(
|
||||
)
|
||||
if not is_org_admin(ctx.member) and member.user_id != ctx.member.user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
user = await session.get(User, member.user_id)
|
||||
access_rows = list(
|
||||
await session.exec(
|
||||
select(OrganizationBoardAccess).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
||||
)
|
||||
)
|
||||
)
|
||||
user = await User.objects.by_id(member.user_id).first(session)
|
||||
access_rows = await OrganizationBoardAccess.objects.filter_by(
|
||||
organization_member_id=member.id
|
||||
).all(session)
|
||||
model = _member_to_read(member, user)
|
||||
model.board_access = [
|
||||
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
|
||||
@@ -374,7 +357,7 @@ async def update_org_member(
|
||||
updates["role"] = normalize_role(updates["role"])
|
||||
updates["updated_at"] = utcnow()
|
||||
member = await crud.patch(session, member, updates)
|
||||
user = await session.get(User, member.user_id)
|
||||
user = await User.objects.by_id(member.user_id).first(session)
|
||||
return _member_to_read(member, user)
|
||||
|
||||
|
||||
@@ -393,20 +376,19 @@ async def update_member_access(
|
||||
|
||||
board_ids = {entry.board_id for entry in payload.board_access}
|
||||
if board_ids:
|
||||
valid_board_ids = set(
|
||||
await session.exec(
|
||||
select(Board.id)
|
||||
.where(col(Board.id).in_(board_ids))
|
||||
.where(col(Board.organization_id) == ctx.organization.id)
|
||||
)
|
||||
)
|
||||
valid_board_ids = {
|
||||
board.id
|
||||
for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
|
||||
.filter(col(Board.id).in_(board_ids))
|
||||
.all(session)
|
||||
}
|
||||
if valid_board_ids != board_ids:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
await apply_member_access_update(session, member=member, update=payload)
|
||||
await session.commit()
|
||||
await session.refresh(member)
|
||||
user = await session.get(User, member.user_id)
|
||||
user = await User.objects.by_id(member.user_id).first(session)
|
||||
return _member_to_read(member, user)
|
||||
|
||||
|
||||
@@ -416,9 +398,11 @@ async def remove_org_member(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
) -> OkResponse:
|
||||
member = await session.get(OrganizationMember, member_id)
|
||||
if member is None or member.organization_id != ctx.organization.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
member = await _require_org_member(
|
||||
session,
|
||||
organization_id=ctx.organization.id,
|
||||
member_id=member_id,
|
||||
)
|
||||
if member.user_id == ctx.member.user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -430,15 +414,12 @@ async def remove_org_member(
|
||||
detail="Only owners can remove owners",
|
||||
)
|
||||
if member.role == "owner":
|
||||
owner_ids = list(
|
||||
await session.exec(
|
||||
select(OrganizationMember.id).where(
|
||||
col(OrganizationMember.organization_id) == ctx.organization.id,
|
||||
col(OrganizationMember.role) == "owner",
|
||||
)
|
||||
)
|
||||
owners = (
|
||||
await OrganizationMember.objects.filter_by(organization_id=ctx.organization.id)
|
||||
.filter(col(OrganizationMember.role) == "owner")
|
||||
.all(session)
|
||||
)
|
||||
if len(owner_ids) <= 1:
|
||||
if len(owners) <= 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
detail="Organization must have at least one owner",
|
||||
@@ -451,17 +432,22 @@ async def remove_org_member(
|
||||
),
|
||||
)
|
||||
|
||||
user = await session.get(User, member.user_id)
|
||||
user = await User.objects.by_id(member.user_id).first(session)
|
||||
if user is not None and user.active_organization_id == ctx.organization.id:
|
||||
fallback_org_id = (
|
||||
await session.exec(
|
||||
select(OrganizationMember.organization_id)
|
||||
.where(col(OrganizationMember.user_id) == user.id)
|
||||
.where(col(OrganizationMember.organization_id) != ctx.organization.id)
|
||||
.order_by(col(OrganizationMember.created_at).asc())
|
||||
fallback_membership = (
|
||||
await OrganizationMember.objects.filter(
|
||||
col(OrganizationMember.user_id) == user.id,
|
||||
col(OrganizationMember.organization_id) != ctx.organization.id,
|
||||
)
|
||||
.order_by(col(OrganizationMember.created_at).asc())
|
||||
.first(session)
|
||||
)
|
||||
if isinstance(fallback_membership, UUID):
|
||||
user.active_organization_id = fallback_membership
|
||||
else:
|
||||
user.active_organization_id = (
|
||||
fallback_membership.organization_id if fallback_membership is not None else None
|
||||
)
|
||||
).first()
|
||||
user.active_organization_id = fallback_org_id
|
||||
session.add(user)
|
||||
|
||||
await crud.delete(session, member)
|
||||
@@ -474,8 +460,7 @@ async def list_org_invites(
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
) -> DefaultLimitOffsetPage[OrganizationInviteRead]:
|
||||
statement = (
|
||||
api_qs(OrganizationInvite)
|
||||
.filter(col(OrganizationInvite.organization_id) == ctx.organization.id)
|
||||
OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)
|
||||
.filter(col(OrganizationInvite.accepted_at).is_(None))
|
||||
.order_by(col(OrganizationInvite.created_at).desc())
|
||||
.statement
|
||||
@@ -522,13 +507,12 @@ async def create_org_invite(
|
||||
|
||||
board_ids = {entry.board_id for entry in payload.board_access}
|
||||
if board_ids:
|
||||
valid_board_ids = set(
|
||||
await session.exec(
|
||||
select(Board.id)
|
||||
.where(col(Board.id).in_(board_ids))
|
||||
.where(col(Board.organization_id) == ctx.organization.id)
|
||||
)
|
||||
)
|
||||
valid_board_ids = {
|
||||
board.id
|
||||
for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
|
||||
.filter(col(Board.id).in_(board_ids))
|
||||
.all(session)
|
||||
}
|
||||
if valid_board_ids != board_ids:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
await apply_invite_board_access(session, invite=invite, entries=payload.board_access)
|
||||
@@ -566,13 +550,10 @@ async def accept_org_invite(
|
||||
) -> OrganizationMemberRead:
|
||||
if auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
invite = (
|
||||
await session.exec(
|
||||
select(OrganizationInvite)
|
||||
.where(col(OrganizationInvite.token) == payload.token)
|
||||
.where(col(OrganizationInvite.accepted_at).is_(None))
|
||||
)
|
||||
).first()
|
||||
invite = await OrganizationInvite.objects.filter(
|
||||
col(OrganizationInvite.token) == payload.token,
|
||||
col(OrganizationInvite.accepted_at).is_(None),
|
||||
).first(session)
|
||||
if invite is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
if invite.invited_email and auth.user.email:
|
||||
@@ -597,5 +578,5 @@ async def accept_org_invite(
|
||||
await session.commit()
|
||||
member = existing
|
||||
|
||||
user = await session.get(User, member.user_id)
|
||||
user = await User.objects.by_id(member.user_id).first(session)
|
||||
return _member_to_read(member, user)
|
||||
|
||||
Reference in New Issue
Block a user