vdkch/main.py
2025-02-17 14:05:25 +03:00

365 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional
import uuid
import os
from datetime import datetime
import requests
from qdrant_client import QdrantClient
from qdrant_client.http import models
from fastapi.middleware.cors import CORSMiddleware
import time # Добавлен импорт модуля time
# Database setup
from sqlalchemy import create_engine, Column, String, DateTime, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# Image processing
from PIL import Image
import io
import torch
from transformers import CLIPProcessor, CLIPModel
import logging
from qdrant_client.http.models import VectorParams, Distance
import os
from dotenv import load_dotenv
# Настройка базового логирования
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Загрузить переменные окружения из файла .env
load_dotenv()
# Получение значений с указанием значений по умолчанию
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "posts")
VECTOR_SIZE = int(os.getenv("VECTOR_SIZE", 1280))
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./imageboard.db")
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11435")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
IMAGE_MODEL = os.getenv("IMAGE_MODEL", "openai/clip-vit-base-patch32")
# IMAGE_SIZE ожидается в формате "224,224", преобразуем его в кортеж чисел
image_size_str = os.getenv("IMAGE_SIZE", "224,224")
IMAGE_SIZE = tuple(map(int, image_size_str.split(',')))
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploads")
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Инициализация компонентов
Base = declarative_base()
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Инициализация CLIP
clip_model = CLIPModel.from_pretrained(IMAGE_MODEL)
clip_processor = CLIPProcessor.from_pretrained(IMAGE_MODEL)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = clip_model.to(device)
# Функция для создания QdrantClient с повторными попытками
def create_qdrant_client():
max_attempts = 5
attempt = 0
while attempt < max_attempts:
try:
client = QdrantClient(QDRANT_URL)
# Проверка подключения
client.get_collections()
logger.info("Успешное подключение к Qdrant")
return client
except Exception as e:
logger.warning(f"Попытка {attempt+1} подключения к Qdrant не удалась: {str(e)}")
attempt += 1
time.sleep(2)
raise RuntimeError(f"Не удалось подключиться к Qdrant после {max_attempts} попыток")
# Инициализация клиента Qdrant
try:
qdrant_client = create_qdrant_client()
except Exception as e:
logger.error(f"Ошибка инициализации Qdrant: {str(e)}")
raise
# Функция проверки и создания коллекции
def ensure_collection_exists():
max_attempts = 5
attempt = 0
while attempt < max_attempts:
try:
# Проверка существования коллекции
qdrant_client.get_collection(collection_name=COLLECTION_NAME)
logger.info(f"Коллекция '{COLLECTION_NAME}' существует")
return
except Exception as e:
if "not found" in str(e).lower():
logger.info(f"Создание коллекции '{COLLECTION_NAME}'...")
try:
qdrant_client.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(
size=VECTOR_SIZE,
distance=Distance.COSINE
)
)
logger.info(f"Коллекция '{COLLECTION_NAME}' создана")
return
except Exception as create_error:
logger.error(f"Ошибка создания: {str(create_error)}")
else:
logger.error(f"Ошибка подключения: {str(e)}")
attempt += 1
time.sleep(2)
raise RuntimeError(f"Не удалось инициализировать коллекцию после {max_attempts} попыток")
# Вызов функции проверки коллекции
try:
ensure_collection_exists()
except Exception as e:
logger.error(f"Ошибка инициализации коллекции: {str(e)}")
raise
# Инициализация FastAPI
app = FastAPI()
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Модель базы данных
class Post(Base):
__tablename__ = "posts"
id = Column(String, primary_key=True, index=True)
text = Column(Text, nullable=True)
image = Column(String, nullable=True)
created_at = Column(DateTime)
Base.metadata.create_all(bind=engine)
# Pydantic model для ответа
class PostResponse(BaseModel):
id: str
text: Optional[str] = None
image: Optional[str] = None
created_at: datetime
vector: Optional[List[float]] = None
class Config:
orm_mode = True
# Pydantic модель для запроса вектора
class VectorQuery(BaseModel):
vector: List[float]
# Utility functions
def generate_text_embedding(text: str) -> List[float]:
response = requests.post(
f"{OLLAMA_URL}/api/embeddings",
json={"model": EMBEDDING_MODEL, "prompt": text}
)
if response.status_code != 200:
raise HTTPException(status_code=500, detail="Embedding generation failed")
return response.json()["embedding"]
def generate_image_embedding(image_bytes: bytes) -> List[float]:
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
inputs = clip_processor(
images=image,
return_tensors="pt",
padding=True
).to(device)
with torch.no_grad():
features = clip_model.get_image_features(**inputs)
return features.cpu().numpy().tolist()[0]
def process_image(image_bytes: bytes) -> bytes:
img = Image.open(io.BytesIO(image_bytes))
img = img.convert("RGB")
img = img.resize(IMAGE_SIZE)
buffer = io.BytesIO()
img.save(buffer, format="JPEG")
return buffer.getvalue()
# API endpoints
# API endpoints
@app.post("/posts/", response_model=PostResponse)
async def create_post(
text: Optional[str] = Form(None),
image: Optional[UploadFile] = File(None)
):
db = SessionLocal()
try:
post_id = str(uuid.uuid4())
image_path = None
thumbnail_path = None # Placeholder for thumbnail
embeddings = []
if text:
logger.info("Генерация эмбеддинга для текста")
text_embedding = generate_text_embedding(text)
embeddings.extend(text_embedding)
if image:
logger.info("Обработка изображения")
image_bytes = await image.read()
# Save original image
image_path = f"{UPLOAD_DIR}/{post_id}.jpg"
with open(image_path, "wb") as f:
f.write(image_bytes)
# Create processed image for embeddings
processed_image = process_image(image_bytes) # Assume this resizes/image processing
# Generate thumbnail as placeholder (example implementation)
#thumbnail = generate_thumbnail(image_bytes) # Implement your thumbnail generation
#thumbnail_path = f"{UPLOAD_DIR}/{post_id}_thumbnail.jpg"
#with open(thumbnail_path, "wb") as f:
# f.write(thumbnail)
# Generate embeddings from processed image
image_embedding = generate_image_embedding(processed_image)
embeddings.extend(image_embedding)
logger.info("Сохранение данных в Qdrant")
qdrant_client.upsert(
collection_name="posts",
points=[models.PointStruct(
id=post_id,
vector=embeddings,
payload={"post_id": post_id}
)]
)
logger.info("Сохранение поста в базу данных")
db_post = Post(
id=post_id,
text=text,
image=image_path,
#thumbnail=thumbnail_path, Add thumbnail field to your Post model
created_at=datetime.now()
)
db.add(db_post)
db.commit()
db.refresh(db_post)
response = PostResponse(
id=db_post.id,
text=db_post.text,
image=db_post.image,
#thumbnail=db_post.thumbnail, Update response model to include thumbnail
created_at=db_post.created_at,
vector=embeddings
)
logger.info("Пост успешно создан: %s", response)
return response
except Exception as e:
db.rollback()
logger.exception("Ошибка при создании поста")
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@app.get("/search/")
async def search_posts(
text: Optional[str] = None,
image: Optional[UploadFile] = File(None)
):
try:
query_embedding = []
if text:
logger.info("Генерация эмбеддинга для текста (поиск)")
text_embedding = generate_text_embedding(text)
query_embedding.extend(text_embedding)
if image:
logger.info("Генерация эмбеддинга для изображения (поиск)")
image_bytes = await image.read()
processed_image = process_image(image_bytes)
image_embedding = generate_image_embedding(processed_image)
query_embedding.extend(image_embedding)
logger.info("Выполнение поиска в Qdrant")
search_results = qdrant_client.search(
collection_name="posts",
query_vector=query_embedding,
limit=10
)
logger.info("Поиск завершён. Найдено результатов: %d", len(search_results))
return [result.payload for result in search_results]
except Exception as e:
logger.exception("Ошибка при поиске постов")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/posts/", response_model=List[PostResponse])
async def get_all_posts():
db = SessionLocal()
try:
posts = db.query(Post).all()
return posts
finally:
db.close()
# Новый endpoint: получение "древа" постов по вектору пользователя
@app.post("/posts/tree", response_model=List[PostResponse])
async def get_posts_tree(query: VectorQuery):
# Выполняем поиск в Qdrant с большим лимитом, чтобы получить все посты, отсортированные по сходству
search_results = qdrant_client.search(
collection_name="posts",
query_vector=query.vector,
limit=10000 # Задайте лимит в зависимости от ожидаемого числа постов
)
print("search_results")
print(search_results)
# Извлекаем список ID постов в том порядке, в котором Qdrant вернул результаты (от ближайших к дальним)
post_ids = [result.payload.get("post_id") for result in search_results]
print("post_ids")
print(post_ids)
db = SessionLocal()
try:
# Получаем все посты из БД по списку ID
posts = db.query(Post).filter(Post.id.in_(post_ids)).all()
print("posts")
print(posts)
# Создаём словарь для сохранения соответствия post_id -> post
posts_dict = {post.id: post for post in posts}
print("posts_dict")
print(posts_dict)
# Восстанавливаем порядок, используя список post_ids
ordered_posts = [posts_dict[pid] for pid in post_ids if pid in posts_dict]
print("ordered_posts")
print(ordered_posts)
return ordered_posts
finally:
db.close()
if __name__ == "__main__":
import uvicorn
print("COLLECTION_NAME:", COLLECTION_NAME)
print("VECTOR_SIZE:", VECTOR_SIZE)
print("DATABASE_URL:", DATABASE_URL)
print("QDRANT_URL:", QDRANT_URL)
print("OLLAMA_URL:", OLLAMA_URL)
print("EMBEDDING_MODEL:", EMBEDDING_MODEL)
print("IMAGE_MODEL:", IMAGE_MODEL)
print("IMAGE_SIZE:", IMAGE_SIZE)
print("UPLOAD_DIR:", UPLOAD_DIR)
uvicorn.run(app, host="0.0.0.0", port=8000)